import json

from holmes.common.env_vars import SSE_READ_TIMEOUT
from holmes.core.tools import (
    ToolInvokeContext,
    Toolset,
    Tool,
    ToolParameter,
    StructuredToolResult,
    StructuredToolResultStatus,
    CallablePrerequisite,
)

from typing import Dict, Any, List, Optional, Union
from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client
from mcp.client.stdio import stdio_client, StdioServerParameters

from mcp.types import Tool as MCP_Tool

import asyncio
from contextlib import asynccontextmanager
from pydantic import BaseModel, Field, AnyUrl, model_validator
from typing import Tuple
import logging
from enum import Enum
import threading

# Lock per MCP server URL to serialize calls to the same server
_server_locks: Dict[str, threading.Lock] = {}
_locks_lock = threading.Lock()


def get_server_lock(url: str) -> threading.Lock:
    """Get or create a lock for a specific MCP server URL."""
    with _locks_lock:
        if url not in _server_locks:
            _server_locks[url] = threading.Lock()
        return _server_locks[url]


class MCPMode(str, Enum):
    SSE = "sse"
    STREAMABLE_HTTP = "streamable-http"
    STDIO = "stdio"


class MCPConfig(BaseModel):
    url: AnyUrl
    mode: MCPMode = MCPMode.SSE
    headers: Optional[Dict[str, str]] = None

    def get_lock_string(self) -> str:
        return str(self.url)


class StdioMCPConfig(BaseModel):
    mode: MCPMode = MCPMode.STDIO
    command: str
    args: Optional[List[str]] = None
    env: Optional[Dict[str, str]] = None

    def get_lock_string(self) -> str:
        return str(self.command)


@asynccontextmanager
async def get_initialized_mcp_session(toolset: "RemoteMCPToolset"):
    if toolset._mcp_config is None:
        raise ValueError("MCP config is not initialized")

    if isinstance(toolset._mcp_config, StdioMCPConfig):
        server_params = StdioServerParameters(
            command=toolset._mcp_config.command,
            args=toolset._mcp_config.args or [],
            env=toolset._mcp_config.env,
        )
        async with stdio_client(server_params) as (
            read_stream,
            write_stream,
        ):
            async with ClientSession(read_stream, write_stream) as session:
                _ = await session.initialize()
                yield session
    elif toolset._mcp_config.mode == MCPMode.SSE:
        url = str(toolset._mcp_config.url)
        async with sse_client(
            url, toolset._mcp_config.headers, sse_read_timeout=SSE_READ_TIMEOUT
        ) as (
            read_stream,
            write_stream,
        ):
            async with ClientSession(read_stream, write_stream) as session:
                _ = await session.initialize()
                yield session
    else:
        url = str(toolset._mcp_config.url)
        async with streamablehttp_client(
            url, headers=toolset._mcp_config.headers, sse_read_timeout=SSE_READ_TIMEOUT
        ) as (
            read_stream,
            write_stream,
            _,
        ):
            async with ClientSession(read_stream, write_stream) as session:
                _ = await session.initialize()
                yield session


class RemoteMCPTool(Tool):
    toolset: "RemoteMCPToolset" = Field(exclude=True)

    def _invoke(self, params: dict, context: ToolInvokeContext) -> StructuredToolResult:
        try:
            # Serialize calls to the same MCP server to prevent SSE conflicts
            # Different servers can still run in parallel
            if not self.toolset._mcp_config:
                raise ValueError("MCP config not initialized")

            lock = get_server_lock(str(self.toolset._mcp_config.get_lock_string()))
            with lock:
                return asyncio.run(self._invoke_async(params))
        except Exception as e:
            return StructuredToolResult(
                status=StructuredToolResultStatus.ERROR,
                error=str(e.args),
                params=params,
                invocation=f"MCPtool {self.name} with params {params}",
            )

    @staticmethod
    def _is_content_error(content: str) -> bool:
        try:  # aws mcp sometimes returns an error in content - status code != 200
            json_content: dict = json.loads(content)
            status_code = json_content.get("response", {}).get("status_code", 200)
            return status_code >= 300
        except Exception:
            return False

    async def _invoke_async(self, params: Dict) -> StructuredToolResult:
        async with get_initialized_mcp_session(self.toolset) as session:
            tool_result = await session.call_tool(self.name, params)

        merged_text = " ".join(c.text for c in tool_result.content if c.type == "text")
        return StructuredToolResult(
            status=(
                StructuredToolResultStatus.ERROR
                if (tool_result.isError or self._is_content_error(merged_text))
                else StructuredToolResultStatus.SUCCESS
            ),
            data=merged_text,
            params=params,
            invocation=f"MCPtool {self.name} with params {params}",
        )

    @classmethod
    def create(
        cls,
        tool: MCP_Tool,
        toolset: "RemoteMCPToolset",
    ):
        parameters = cls.parse_input_schema(tool.inputSchema)
        return cls(
            name=tool.name,
            description=tool.description or "",
            parameters=parameters,
            toolset=toolset,
        )

    @classmethod
    def parse_input_schema(
        cls, input_schema: dict[str, Any]
    ) -> Dict[str, ToolParameter]:
        required_list = input_schema.get("required", [])
        schema_params = input_schema.get("properties", {})
        parameters = {}
        for key, val in schema_params.items():
            parameters[key] = ToolParameter(
                description=val.get("description"),
                type=val.get("type", "string"),
                required=key in required_list,
            )

        return parameters

    def get_parameterized_one_liner(self, params: Dict) -> str:
        if params:
            if params.get("cli_command"):  # Return AWS MCP cli command, if available
                return f"{params.get('cli_command')}"

        if isinstance(self.toolset._mcp_config, MCPConfig):
            cmd = str(self.toolset._mcp_config.url)
        elif isinstance(self.toolset._mcp_config, StdioMCPConfig):
            cmd = self.toolset._mcp_config.command
        else:
            cmd = "unknown"
        return f"Call MCP Server ({cmd} - {self.name})"


class RemoteMCPToolset(Toolset):
    tools: List[RemoteMCPTool] = Field(default_factory=list)  # type: ignore
    icon_url: str = "https://registry.npmmirror.com/@lobehub/icons-static-png/1.46.0/files/light/mcp.png"
    _mcp_config: Optional[Union[MCPConfig, StdioMCPConfig]] = None

    def model_post_init(self, __context: Any) -> None:
        self.prerequisites = [
            CallablePrerequisite(callable=self.prerequisites_callable)
        ]

    @model_validator(mode="before")
    @classmethod
    def migrate_url_to_config(cls, values: dict[str, Any]) -> dict[str, Any]:
        """
        Migrates url from field parameter to config object.
        If url is passed as a parameter, it's moved to config (or config is created if it doesn't exist).
        """
        if not isinstance(values, dict) or "url" not in values:
            return values

        url_value = values.pop("url")
        if url_value is None:
            return values

        config = values.get("config")
        if config is None:
            config = {}
            values["config"] = config

        toolset_name = values.get("name", "unknown")
        if "url" in config:
            logging.warning(
                f"Toolset {toolset_name}: has two urls defined, remove the 'url' field from the toolset configuration and keep the 'url' in the config section."
            )
            return values

        logging.warning(
            f"Toolset {toolset_name}: 'url' field has been migrated to config. "
            "Please move 'url' to the config section."
        )
        config["url"] = url_value
        return values

    def prerequisites_callable(self, config) -> Tuple[bool, str]:
        try:
            if not config:
                return (False, f"Config is required for {self.name}")

            mode_value = config.get("mode", MCPMode.SSE.value)
            allowed_modes = [e.value for e in MCPMode]
            if mode_value not in allowed_modes:
                return (
                    False,
                    f'Invalid mode "{mode_value}", allowed modes are {", ".join(allowed_modes)}',
                )

            if mode_value == MCPMode.STDIO.value:
                self._mcp_config = StdioMCPConfig(**config)
            else:
                self._mcp_config = MCPConfig(**config)
                clean_url_str = str(self._mcp_config.url).rstrip("/")

                if self._mcp_config.mode == MCPMode.SSE and not clean_url_str.endswith(
                    "/sse"
                ):
                    self._mcp_config.url = AnyUrl(clean_url_str + "/sse")

            tools_result = asyncio.run(self._get_server_tools())

            self.tools = [
                RemoteMCPTool.create(tool, self) for tool in tools_result.tools
            ]

            if not self.tools:
                logging.warning(f"mcp server {self.name} loaded 0 tools.")

            return (True, "")
        except Exception as e:
            return (
                False,
                f"Failed to load mcp server {self.name}: {str(e)}",
            )

    async def _get_server_tools(self):
        async with get_initialized_mcp_session(self) as session:
            return await session.list_tools()

    def get_example_config(self) -> Dict[str, Any]:
        example_config = MCPConfig(
            url=AnyUrl("http://example.com:8000/mcp/messages"),
            mode=MCPMode.STREAMABLE_HTTP,
            headers={"Authorization": "Bearer YOUR_TOKEN"},
        )
        return example_config.model_dump()
